{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Offline evaluation of Pose Estimation models using PyCocoTools\n",
    "\n",
    "This example shows how you can evaluate pose estimation models using PyCocoEval class from pycocotools python package.\n",
    "Although we provide a ready-to-use metric class to compute average precision (AP) and average recall (AR) scores, the \n",
    "evaluation protocol during validation is slightly different from what pycocotools suggests for academic evaluation.\n",
    "\n",
    "In particular:\n",
    "\n",
    "## SG\n",
    "\n",
    "* In SG, during training/validation, we resize all images to a fixed size (Default is 640x640) using aspect-ratio preserving resize of the longest size + padding. \n",
    "* Our metric evaluate AP/AR in the resolution of the resized & padded images, **not in the resolution of original image**. \n",
    "\n",
    "\n",
    "## COCOEval\n",
    "\n",
    "* In COCOEval all images are not resized and pose predictions are evaluated in the resolution of original image \n",
    "\n",
    "Because of this discrepancy, metrics reported by `PoseEstimationMetrics` class is usually a bit lower (Usually by ~1AP) than the ones \n",
    "you would get from the same model if computed with COCOEval. \n",
    "\n",
    "For this reason we provide this example to show how you can compute metrics using COCOEval for pose estimation models that are available in SuperGradients."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "f36fcd2e6daa861c"
  },
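  {
   "cell_type": "markdown",
   "source": [
    "To make the resolution difference concrete, the snippet below sketches how keypoints expressed in the resized & padded image relate to coordinates in the original image. \n",
    "It is a minimal illustration, not the SuperGradients implementation: `rescale_poses_to_original` is a hypothetical helper, and it assumes that padding is added only at the bottom/right, so no offset has to be subtracted.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def rescale_poses_to_original(poses, original_hw, resized_hw=(640, 640)):\n",
    "    # poses: array of shape [num_persons, num_joints, 3] holding (x, y, confidence)\n",
    "    # The longest side is fitted into the target size, so a single scale factor is used for both axes\n",
    "    scale = min(resized_hw[0] / original_hw[0], resized_hw[1] / original_hw[1])\n",
    "    restored = np.array(poses, dtype=np.float32, copy=True)\n",
    "    # Only x and y are rescaled; the confidence channel is left untouched\n",
    "    restored[..., 0:2] /= scale\n",
    "    return restored\n",
    "\n",
    "\n",
    "# A 480x1280 (height x width) image is scaled by 0.5 to fit into 640x640\n",
    "poses = np.array([[[320.0, 240.0, 0.9]]])\n",
    "print(rescale_poses_to_original(poses, original_hw=(480, 1280)))\n",
    "```\n",
    "\n",
    "Since the keypoint similarity used by COCOeval depends on pixel distances and object area, evaluating at one resolution or the other can give slightly different scores."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "b7e4d1a09c3f5e21"
  },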
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "!pip install -qq super_gradients==3.7.1"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "319deb6e3293a2c9"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Instantiate the model for evaluation\n",
    "\n",
    "First, let's instantiate the model we are going to evaluate. \n",
    "You can use either pretrained models or provide a checkpoint path to your own trained checkpoint.\n",
    "\n",
    "```python\n",
    "# This is how you can load your custom checkpoint instead of pretrained one\n",
    "model = models.get(\n",
    "    Models.YOLO_NAS_POSE_L,\n",
    "    num_classes=17,\n",
    "    checkpoint_path=\"/Absolute/Path/To/Your/Custom/Checkpoint.pth\",\n",
    ")\n",
    "```\n",
    "In this example we will be using pretrained weights for simplicity."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "33367b7bd09fccbb"
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "from super_gradients.common.object_names import Models\n",
    "from super_gradients.training import models\n",
    "\n",
    "model = models.get(\n",
    "    Models.YOLO_NAS_POSE_L,\n",
    "    pretrained_weights=\"coco_pose\"\n",
    ").cuda()\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-03T10:03:53.294954700Z",
     "start_time": "2023-11-03T10:03:51.750726600Z"
    }
   },
   "id": "36c9acfa4c5057c"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Prepare COCO validation data\n",
    "\n",
    "Next, we obtain list of images in COCO2017 validation set and load their annotations.\n",
    "You may want to either set the COCO_ROOT_DIR environment variable where COCO2017 data is located on your machine or edit the default path directylu"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "372d6ce9605c68f9"
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "data": {
      "text/plain": "['annotations', 'images']"
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "COCO_DATA_DIR = os.environ.get(\"COCO_ROOT_DIR\", \"g:/coco2017\")\n",
    "os.listdir(COCO_DATA_DIR)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-03T10:03:53.325600100Z",
     "start_time": "2023-11-03T10:03:53.298993900Z"
    }
   },
   "id": "91e2ee3f26130e1a"
  },
  {
   "cell_type": "markdown",
   "source": [
    "Once data is set we can load it"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "9a8be6a43b014149"
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [],
   "source": [
    "from pycocotools.cocoeval import COCOeval"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-03T10:03:53.361646400Z",
     "start_time": "2023-11-03T10:03:53.312545300Z"
    }
   },
   "id": "ca413a0cfd65d3d9"
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [],
   "source": [
    "from pycocotools.coco import COCO\n",
    "\n",
    "images_path = os.path.join(COCO_DATA_DIR, \"images/val2017\")\n",
    "image_files = [os.path.join(images_path, x) for x in os.listdir(images_path)]\n",
    "\n",
    "gt_annotations_path = os.path.join(COCO_DATA_DIR, \"annotations/person_keypoints_val2017.json\")\n",
    "gt = COCO(gt_annotations_path)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-03T10:03:54.397071500Z",
     "start_time": "2023-11-03T10:03:53.330346300Z"
    }
   },
   "id": "31ef5c6c808fc5c2"
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Predicting Images:  99%|█████████▉| 4961/5000 [01:53<00:00, 47.22it/s]"
     ]
    }
   ],
   "source": [
    "predictions = model.predict(\n",
    "    image_files, conf=0.01, iou=0.7, pre_nms_max_predictions=300, post_nms_max_predictions=20, fuse_model=False\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-03T10:06:38.467559300Z",
     "start_time": "2023-11-03T10:03:54.380901900Z"
    }
   },
   "id": "7d20871ce397f6cb"
  },
  {
   "cell_type": "markdown",
   "source": [
    "At this point we have all the predictions are completed, and what is left is to convert our predictions to the format required by pycocotools and \n",
    "send them to COCOEval class to compute final metrics "
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "cdc7518861546b18"
  },
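  {
   "cell_type": "markdown",
   "source": [
    "For reference, the conversion produces a list with one dictionary per detected person. \n",
    "The sketch below shows the fields of one such entry in the COCO keypoint results format; all values are made up for illustration and `example_result` is a hypothetical name. \n",
    "The converter used in this notebook additionally stores a `bbox` field derived from the keypoints.\n",
    "\n",
    "```python\n",
    "num_joints = 17  # number of keypoints per person in COCO\n",
    "\n",
    "example_result = {\n",
    "    \"image_id\": 139,  # integer id of the image the pose belongs to\n",
    "    \"category_id\": 1,  # id of the person category in COCO\n",
    "    # Flat list of x1, y1, s1, x2, y2, s2, ... where the third value of each triplet\n",
    "    # holds the per-keypoint confidence in this notebook\n",
    "    \"keypoints\": [0.0] * (num_joints * 3),\n",
    "    \"score\": 0.87,  # confidence of the whole pose, used to rank detections\n",
    "}\n",
    "\n",
    "assert len(example_result[\"keypoints\"]) == num_joints * 3\n",
    "```"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "c2f9a6e41d7b8035"
  },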
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [],
   "source": [
    "import copy\n",
    "import json_tricks as json\n",
    "import collections\n",
    "import numpy as np\n",
    "import tempfile\n",
    "\n",
    "def predictions_to_coco(predictions, image_files):\n",
    "    predicted_poses = []\n",
    "    predicted_scores = []\n",
    "    non_empty_image_ids = []\n",
    "    for image_file, image_predictions in zip(image_files, predictions):\n",
    "        non_empty_image_ids.append(int(os.path.splitext(os.path.basename(image_file))[0]))\n",
    "        predicted_poses.append(image_predictions.prediction.poses)\n",
    "        predicted_scores.append(image_predictions.prediction.scores)\n",
    "\n",
    "    coco_pred = _coco_convert_predictions_to_dict(predicted_poses, predicted_scores, non_empty_image_ids)\n",
    "    return coco_pred\n",
    "\n",
    "def _coco_process_keypoints(keypoints):\n",
    "    tmp = keypoints.copy()\n",
    "    if keypoints[:, 2].max() > 0:\n",
    "        num_keypoints = keypoints.shape[0]\n",
    "        for i in range(num_keypoints):\n",
    "            tmp[i][0:3] = [float(keypoints[i][0]), float(keypoints[i][1]), float(keypoints[i][2])]\n",
    "\n",
    "    return tmp\n",
    "\n",
    "def _coco_convert_predictions_to_dict(predicted_poses, predicted_scores, image_ids):\n",
    "    kpts = collections.defaultdict(list)\n",
    "    for poses, scores, image_id_int in zip(predicted_poses, predicted_scores, image_ids):\n",
    "\n",
    "        for person_index, kpt in enumerate(poses):\n",
    "            area = (np.max(kpt[:, 0]) - np.min(kpt[:, 0])) * (np.max(kpt[:, 1]) - np.min(kpt[:, 1]))\n",
    "            kpt = _coco_process_keypoints(kpt)\n",
    "            kpts[image_id_int].append({\"keypoints\": kpt[:, 0:3], \"score\": float(scores[person_index]), \"image\": image_id_int, \"area\": area})\n",
    "\n",
    "    oks_nmsed_kpts = []\n",
    "    # image x person x (keypoints)\n",
    "    for img in kpts.keys():\n",
    "        # person x (keypoints)\n",
    "        img_kpts = kpts[img]\n",
    "        # person x (keypoints)\n",
    "        # do not use nms, keep all detections\n",
    "        keep = []\n",
    "        if len(keep) == 0:\n",
    "            oks_nmsed_kpts.append(img_kpts)\n",
    "        else:\n",
    "            oks_nmsed_kpts.append([img_kpts[_keep] for _keep in keep])\n",
    "\n",
    "    classes = [\"__background__\", \"person\"]\n",
    "    _class_to_coco_ind = {cls: i for i, cls in enumerate(classes)}\n",
    "\n",
    "    data_pack = [\n",
    "        {\"cat_id\": _class_to_coco_ind[cls], \"cls_ind\": cls_ind, \"cls\": cls, \"ann_type\": \"keypoints\", \"keypoints\": oks_nmsed_kpts}\n",
    "        for cls_ind, cls in enumerate(classes)\n",
    "        if not cls == \"__background__\"\n",
    "    ]\n",
    "\n",
    "    results = _coco_keypoint_results_one_category_kernel(data_pack[0], num_joints=17)\n",
    "    return results\n",
    "\n",
    "def _coco_keypoint_results_one_category_kernel(data_pack, num_joints: int):\n",
    "    cat_id = data_pack[\"cat_id\"]\n",
    "    keypoints = data_pack[\"keypoints\"]\n",
    "    cat_results = []\n",
    "\n",
    "    for img_kpts in keypoints:\n",
    "        if len(img_kpts) == 0:\n",
    "            continue\n",
    "\n",
    "        _key_points = np.array([img_kpts[k][\"keypoints\"] for k in range(len(img_kpts))])\n",
    "        key_points = np.zeros((_key_points.shape[0], num_joints * 3), dtype=np.float32)\n",
    "\n",
    "        for ipt in range(num_joints):\n",
    "            key_points[:, ipt * 3 + 0] = _key_points[:, ipt, 0]\n",
    "            key_points[:, ipt * 3 + 1] = _key_points[:, ipt, 1]\n",
    "            # keypoints score.\n",
    "            key_points[:, ipt * 3 + 2] = _key_points[:, ipt, 2]\n",
    "\n",
    "        for k in range(len(img_kpts)):\n",
    "            kpt = key_points[k].reshape((num_joints, 3))\n",
    "            left_top = np.amin(kpt, axis=0)\n",
    "            right_bottom = np.amax(kpt, axis=0)\n",
    "\n",
    "            w = right_bottom[0] - left_top[0]\n",
    "            h = right_bottom[1] - left_top[1]\n",
    "\n",
    "            cat_results.append(\n",
    "                {\n",
    "                    \"image_id\": img_kpts[k][\"image\"],\n",
    "                    \"category_id\": cat_id,\n",
    "                    \"keypoints\": list(key_points[k]),\n",
    "                    \"score\": img_kpts[k][\"score\"],\n",
    "                    \"bbox\": list([left_top[0], left_top[1], w, h]),\n",
    "                }\n",
    "            )\n",
    "\n",
    "    return cat_results\n",
    "\n",
    "coco_pred = predictions_to_coco(predictions, image_files)\n",
    "\n",
    "with tempfile.TemporaryDirectory() as td:\n",
    "    res_file = os.path.join(td, \"keypoints_coco2017_results.json\")\n",
    "\n",
    "    with open(res_file, \"w\") as f:\n",
    "        json.dump(coco_pred, f)\n",
    "\n",
    "    coco_dt = copy.deepcopy(gt)\n",
    "    coco_dt = coco_dt.loadRes(res_file)\n",
    "\n",
    "    coco_evaluator = COCOeval(gt, coco_dt, iouType=\"keypoints\")\n",
    "    coco_evaluator.evaluate()  # run per image evaluation\n",
    "    coco_evaluator.accumulate()  # accumulate per image results\n",
    "    coco_evaluator.summarize()  # display summary metrics of results"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-03T10:11:52.947199100Z",
     "start_time": "2023-11-03T10:09:13.016616200Z"
    }
   },
   "id": "a077a3e5d1ed0342"
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [
    {
     "data": {
      "text/plain": "array([0.68241641, 0.89067774, 0.75156452, 0.63089194, 0.76596845,\n       0.73526448, 0.92411839, 0.79927582, 0.68254575, 0.81062802])"
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coco_evaluator.stats"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-03T10:11:52.956433500Z",
     "start_time": "2023-11-03T10:11:52.953735700Z"
    }
   },
   "id": "bc3ab706ff4eeb9"
  },
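  {
   "cell_type": "markdown",
   "source": [
    "For keypoint evaluation, `coco_evaluator.stats` holds 10 values in a fixed order: five AP entries followed by five AR entries. \n",
    "The sketch below attaches a label to each position; `KEYPOINT_STAT_NAMES` and `name_stats` are hypothetical helpers introduced only for illustration, and the numbers are copied from the output above.\n",
    "\n",
    "```python\n",
    "# Labels for the entries of COCOeval.stats when iouType=\"keypoints\", in order\n",
    "KEYPOINT_STAT_NAMES = [\n",
    "    \"AP @0.5..0.95\", \"AP @0.5\", \"AP @0.75\", \"AP (medium)\", \"AP (large)\",\n",
    "    \"AR @0.5..0.95\", \"AR @0.5\", \"AR @0.75\", \"AR (medium)\", \"AR (large)\",\n",
    "]\n",
    "\n",
    "\n",
    "def name_stats(stats):\n",
    "    # Pair every value with its label; float() also accepts NumPy scalars\n",
    "    return dict(zip(KEYPOINT_STAT_NAMES, (float(s) for s in stats)))\n",
    "\n",
    "\n",
    "stats = [0.68241641, 0.89067774, 0.75156452, 0.63089194, 0.76596845,\n",
    "         0.73526448, 0.92411839, 0.79927582, 0.68254575, 0.81062802]\n",
    "\n",
    "for name, value in name_stats(stats).items():\n",
    "    print(f\"{name:14s} {value:.4f}\")\n",
    "```\n",
    "\n",
    "The cells below read individual entries by index, which corresponds to positions 0, 1, 2, 5, 6 and 7 of this list."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "d40b7c8e19a2f653"
  },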
  {
   "cell_type": "code",
   "execution_count": 15,
   "outputs": [
    {
     "data": {
      "text/plain": "('AP @0.5..0.95', 0.6824164138074639)"
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"AP @0.5..0.95\", coco_evaluator.stats[0]"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-03T10:12:17.645470100Z",
     "start_time": "2023-11-03T10:12:17.615005500Z"
    }
   },
   "id": "7f1f02de114a3ed4"
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [
    {
     "data": {
      "text/plain": "('AP @ 0.5', 0.8906777403136868)"
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"AP @ 0.5\", coco_evaluator.stats[1]"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-03T10:12:17.719299200Z",
     "start_time": "2023-11-03T10:12:17.625340900Z"
    }
   },
   "id": "33d5fa9eb8d719de"
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "outputs": [
    {
     "data": {
      "text/plain": "('AP @ 0.75', 0.7515645236306682)"
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"AP @ 0.75\", coco_evaluator.stats[2]"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-03T10:12:17.746369400Z",
     "start_time": "2023-11-03T10:12:17.644448900Z"
    }
   },
   "id": "8d4af21fbf550190"
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "outputs": [
    {
     "data": {
      "text/plain": "('AR @0.5..0.95', 0.7352644836272041)"
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"AR @0.5..0.95\", coco_evaluator.stats[5]"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-03T10:12:17.746369400Z",
     "start_time": "2023-11-03T10:12:17.656524900Z"
    }
   },
   "id": "d130f7ccdf406c53"
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "outputs": [
    {
     "data": {
      "text/plain": "('AR @0.5', 0.9241183879093199)"
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"AR @0.5\", coco_evaluator.stats[6]"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-03T10:12:17.746369400Z",
     "start_time": "2023-11-03T10:12:17.675061800Z"
    }
   },
   "id": "58cbffec33dadf77"
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "outputs": [
    {
     "data": {
      "text/plain": "('AR @0.75', 0.7992758186397985)"
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"AR @0.75\", coco_evaluator.stats[7]"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-03T10:12:17.746369400Z",
     "start_time": "2023-11-03T10:12:17.687096600Z"
    }
   },
   "id": "6ad8f57c80ce723b"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   },
   "id": "ddb3c6c669aae5d"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
